{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1d913f3-3c42-428e-9cc0-8d679d51897a",
   "metadata": {},
   "source": [
    "# PEFT 库 QLoRA 实战 - ChatGLM3-6B\n",
    "\n",
    "通常，模型被量化后不会进一步训练用于下游任务，因为由于权重和激活的较低精度，训练可能不稳定。\n",
    "\n",
    "但是由于PEFT方法只添加额外的可训练参数，这使得我们可以使用PEFT适配器（Adapter）来训练一个量化模型！将量化与PEFT结合起来可以成为在单个GPU上训练大模型的微调策略。\n",
    "\n",
    "例如，`QLoRA` 是一种将模型量化为4位然后使用LoRA进行训练的方法，使得在单个16GB GPU（本教程以 NVIDIA T4为例）上微调一个具有65B参数的大模型成为可能。\n",
    "\n",
    "THUDM Hugging Face 主页：https://huggingface.co/THUDM\n",
    "\n",
    "## 教程说明\n",
    "\n",
    "本教程使用 QLoRA 论文中介绍的量化技术：`NF4 数据类型`、`双量化` 和 `混合精度计算`，在 `ChatGLM3-6b` 模型上实现了 QLoRA 微调。并展示了完整的 QLoRA 微调流程，具体如下：\n",
    "\n",
    "- 数据准备\n",
    "    - 下载数据集\n",
    "    - 设计 Tokenizer 函数处理样本（map、shuffle、flatten）\n",
    "    - 自定义批量数据处理类 DataCollatorForChatGLM\n",
    "- 训练模型\n",
    "    - 加载 ChatGLM3-6B 量化模型\n",
    "    - PEFT 量化模型预处理（prepare_model_for_kbit_training）\n",
    "    - QLoRA 适配器配置（TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING）\n",
    "    - 微调训练超参数配置（TrainingArguments）\n",
    "    - 开启训练（trainer.train)\n",
    "    - 保存QLoRA模型（trainer.model.save_pretrained)\n",
    "- [模型推理](peft_chatglm_inference.ipynb)\n",
    "    - 加载 ChatGLM3-6B 基础模型\n",
    "    - 加载 ChatGLM3-6B QLoRA 模型（PEFT Adapter）\n",
    "    - 微调前后对比"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7fa8105c-6dda-426b-9180-ab9abbc9ce9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义全局变量和参数\n",
    "model_name_or_path = 'THUDM/chatglm3-6b'  # 模型ID或本地路径\n",
    "train_data_path = 'HasturOfficial/adgen'    # 训练数据路径\n",
    "eval_data_path = None                     # 验证数据路径，如果没有则设置为None\n",
    "seed = 8                                 # 随机种子\n",
    "max_input_length = 512                    # 输入的最大长度\n",
    "max_output_length = 1536                  # 输出的最大长度\n",
    "lora_rank = 4                             # LoRA秩\n",
    "lora_alpha = 32                           # LoRA alpha值\n",
    "lora_dropout = 0.05                       # LoRA Dropout率\n",
    "resume_from_checkpoint = None             # 如果从checkpoint恢复训练，指定路径\n",
    "prompt_text = ''                          # 所有数据前的指令文本\n",
    "compute_dtype = 'fp32'                    # 计算数据类型（fp32, fp16, bf16）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93eab798-4ec0-45f0-af87-a3aa880f0888",
   "metadata": {},
   "source": [
    "## 数据准备\n",
    "\n",
    "### 下载数据集\n",
    "\n",
    "从 Hugging Face 加载 adgen 数据集，并tokenize，shuffle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f7517070-eae3-45e1-b6a2-f49f332650ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(train_data_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cbd03e18-05b7-47c3-bb32-520fb6f0bef3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['content', 'summary'],\n",
       "        num_rows: 114599\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['content', 'summary'],\n",
       "        num_rows: 1070\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ea17c452-be25-4424-884d-14ba92c17b79",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import ClassLabel, Sequence\n",
    "import random\n",
    "import pandas as pd\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "def show_random_elements(dataset, num_examples=10):\n",
    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
    "    picks = []\n",
    "    for _ in range(num_examples):\n",
    "        pick = random.randint(0, len(dataset)-1)\n",
    "        while pick in picks:\n",
    "            pick = random.randint(0, len(dataset)-1)\n",
    "        picks.append(pick)\n",
    "    \n",
    "    df = pd.DataFrame(dataset[picks])\n",
    "    for column, typ in dataset.features.items():\n",
    "        if isinstance(typ, ClassLabel):\n",
    "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
    "        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):\n",
    "            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])\n",
    "    display(HTML(df.to_html()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7b4ee2cc-e964-4a47-9b18-67f99d65ff73",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>content</th>\n",
       "      <th>summary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>类型#裙*颜色#蓝色*风格#青春*风格#清新*风格#性感*裙领型#一字领</td>\n",
       "      <td>露出性感的锁骨，展现柔美女人味道，这款裙摆做的很不错，它拥有一字领的设计，呈现出优雅大方的气质。细碎的花朵图案，满满的铺陈在蓝色的基调上，让清新的蓝色，焕发出青春的活力，充满勃勃生机。半透明质感的肩部和下摆，若影若现，飘然出尘，演绎仙女般的自然灵动。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>类型#裙*图案#印花*裙下摆#荷叶边*裙衣门襟#系带</td>\n",
       "      <td>不要对自己的身材不自信，这一款裙子就是要让每一个女生都能拥有自己的风采，骨子里散发出来的好气质。领部系带，给人一种温婉雅致的舒适感。别致印花点缀，简洁明了更精致耐看。荷叶边裙摆，大气耐看显气质彰显女神范。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>类型#裙*颜色#纯色*图案#纯色*裙长#连衣裙*裙领型#立领*裙款式#拼接</td>\n",
       "      <td>这款连衣裙，采用的是纯色的搭配，纯黑配色满满的深沉气质，就如同高级灰配色一样百搭。而表面的双材质拼接设计，丰富层次，却并不显花哨。精致的立领设计无形中彰显了优雅大气的感觉。随性搭配，尽显潮女气质。</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_random_elements(dataset[\"train\"], num_examples=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "489ccc1b-3820-4f72-934d-3bb0875de954",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b875200f-7e13-46d8-954d-bfaa76829769",
   "metadata": {},
   "source": [
    "### 使用 ChatGLM3-6b Tokenizer 处理数据\n",
    "\n",
    "\n",
    "关于 `ignore_label_id` 的设置：\n",
    "\n",
    "在许多自然语言处理和机器学习框架中，`ignore_label_id` 被设置为 -100 是一种常见的约定。这个特殊的值用于标记在计算损失函数时应该被忽略的目标标签。让我们详细了解一下这个选择的原因：\n",
    "\n",
    "1. **损失函数忽略特定值**：训练语言模型时，损失函数（例如交叉熵损失）通常只计算对于模型预测重要或关键的标签的损失。在某些情况下，你可能不希望某些标签对损失计算产生影响。例如，在序列到序列的模型中，输入部分的标签通常被设置为一个忽略值，因为只有输出部分的标签对于训练是重要的。\n",
    "\n",
    "2. **为何选择-100**：这个具体的值是基于实现细节选择的。在 PyTorch 的交叉熵损失函数中，可以指定一个 `ignore_index` 参数。当损失函数看到这个索引值时，它就会忽略对应的输出标签。使用 -100 作为默认值是因为它是一个不太可能出现在标签中的数字（特别是在处理分类问题时，标签通常是从0开始的正整数）。\n",
    "\n",
    "3. **标准化和通用性**：由于这种做法在多个库和框架中被采纳，-100 作为忽略标签的默认值已经变得相对标准化，这有助于维护代码的通用性和可读性。\n",
    "\n",
    "总的来说，将 `ignore_label_id` 设置为 -100 是一种在计算损失时排除特定标签影响的便捷方式。这在处理特定类型的自然语言处理任务时非常有用，尤其是在涉及序列生成或修改的任务中。\n",
    "\n",
    "#### 关于 ChatGLM3 的填充处理说明\n",
    "\n",
    "- input_id（query）里的填充补全了输入长度，目的是不改变原始文本的含义。\n",
    "- label（answer）里的填充会用来跟模型基于 query 生成的结果计算 Loss，为了不影响损失值计算，也需要设置。咱们计算损失时，是针对 answer 部分的 Embedding Vector，因此 label 这样填充，前面的序列就自动忽略掉了，只比较生成内容的 loss。因此，需要将answer前面的部分做忽略填充。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "42e3c1f1-1e5b-4452-babe-234c8ae15bb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# revision='b098244' 版本对应的 ChatGLM3-6B 设置 use_reentrant=False\n",
    "# 最新版本 use_reentrant 被设置为 True，会增加不必要的显存开销\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,\n",
    "                                          trust_remote_code=True,\n",
    "                                          revision='b098244')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2cc4a0fd-3239-4cb4-bf3f-8d8fcfff9af0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokenize_func 函数\n",
    "def tokenize_func(example, tokenizer, ignore_label_id=-100):\n",
    "    \"\"\"\n",
    "    对单个数据样本进行tokenize处理。\n",
    "\n",
    "    参数:\n",
    "    example (dict): 包含'content'和'summary'键的字典，代表训练数据的一个样本。\n",
    "    tokenizer (transformers.PreTrainedTokenizer): 用于tokenize文本的tokenizer。\n",
    "    ignore_label_id (int, optional): 在label中用于填充的忽略ID，默认为-100。\n",
    "\n",
    "    返回:\n",
    "    dict: 包含'tokenized_input_ids'和'labels'的字典，用于模型训练。\n",
    "    \"\"\"\n",
    "\n",
    "    # 构建问题文本\n",
    "    question = prompt_text + example['content']\n",
    "    if example.get('input', None) and example['input'].strip():\n",
    "        question += f'\\n{example[\"input\"]}'\n",
    "\n",
    "    # 构建答案文本\n",
    "    answer = example['summary']\n",
    "\n",
    "    # 对问题和答案文本进行tokenize处理\n",
    "    q_ids = tokenizer.encode(text=question, add_special_tokens=False)\n",
    "    a_ids = tokenizer.encode(text=answer, add_special_tokens=False)\n",
    "\n",
    "    # 如果tokenize后的长度超过最大长度限制，则进行截断\n",
    "    if len(q_ids) > max_input_length - 2:  # 保留空间给gmask和bos标记\n",
    "        q_ids = q_ids[:max_input_length - 2]\n",
    "    if len(a_ids) > max_output_length - 1:  # 保留空间给eos标记\n",
    "        a_ids = a_ids[:max_output_length - 1]\n",
    "\n",
    "    # 构建模型的输入格式\n",
    "    input_ids = tokenizer.build_inputs_with_special_tokens(q_ids, a_ids)\n",
    "    question_length = len(q_ids) + 2  # 加上gmask和bos标记\n",
    "\n",
    "    # 构建标签，对于问题部分的输入使用ignore_label_id进行填充\n",
    "    labels = [ignore_label_id] * question_length + input_ids[question_length:]\n",
    "\n",
    "    return {'input_ids': input_ids, 'labels': labels}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d576b3fa-156a-4486-a267-ad773172d90b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a7cda56-a96a-4787-8db9-a712c82e2bb0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f78a19ed-6862-42fa-896d-2c297b94ce41",
   "metadata": {},
   "outputs": [],
   "source": [
    "column_names = dataset['train'].column_names\n",
    "tokenized_dataset = dataset['train'].map(\n",
    "    lambda example: tokenize_func(example, tokenizer),\n",
    "    batched=False, \n",
    "    remove_columns=column_names\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6fe0b50d-56cf-469f-9b35-0eb573cd1b04",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>input_ids</th>\n",
       "      <th>labels</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[64790, 64792, 30910, 33467, 31010, 56778, 30998, 33692, 31010, 35798, 30998, 32799, 31010, 37785, 30998, 56778, 54888, 31010, 56069, 54762, 56778, 30998, 56778, 54888, 31010, 30913, 54952, 30998, 56778, 56278, 54888, 31010, 54589, 56278, 30998, 56778, 40877, 31010, 54535, 33257, 30910, 54999, 55583, 54617, 43518, 34319, 54530, 56778, 56532, 31123, 55881, 55640, 37411, 54530, 35798, 46839, 31123, 32985, 31841, 51827, 37785, 55810, 32789, 34574, 32799, 31155, 54589, 56278, 30913, 54952, 56778, 55090, 54888, 31123, 55097, 54589, 34319, 32938, 33253, 33355, 56383, 58190, 54557, 55108, 54625, 47745, 56158, 54888, 31155, 54535, 33257, 56069, 54762, 56778, 31002, 5234, 30984, 30994, 55353, 51827, ...]</td>\n",
       "      <td>[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 30910, 54999, 55583, 54617, 43518, 34319, 54530, 56778, 56532, 31123, 55881, 55640, 37411, 54530, 35798, 46839, 31123, 32985, 31841, 51827, 37785, 55810, 32789, 34574, 32799, 31155, 54589, 56278, 30913, 54952, 56778, 55090, 54888, 31123, 55097, 54589, 34319, 32938, 33253, 33355, 56383, 58190, 54557, 55108, 54625, 47745, 56158, 54888, 31155, 54535, 33257, 56069, 54762, 56778, 31002, 5234, 30984, 30994, 55353, 51827, ...]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_random_elements(tokenized_dataset, num_examples=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18b28ee0-2207-41c8-a4cf-166d6c2b7a42",
   "metadata": {},
   "source": [
    "### 数据集处理：shuffle & flatten \n",
    "\n",
    "洗牌(shuffle)会将数据集的索引列表打乱，以创建一个索引映射。\n",
    "\n",
    "然而，一旦您的数据集具有索引映射，速度可能会变慢10倍。这是因为需要额外的步骤来使用索引映射获取要读取的行索引，并且最重要的是，您不再连续地读取数据块。\n",
    "\n",
    "要恢复速度，需要再次使用 Dataset.flatten_indices()将整个数据集重新写入磁盘上，从而删除索引映射。\n",
    "\n",
    "ref: https://huggingface.co/docs/datasets/v2.15.0/en/package_reference/main_classes#datasets.Dataset.flatten_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "75293898-f010-4cae-bf01-d58afc9dc214",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_dataset = tokenized_dataset.shuffle(seed=seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "99f17038-ec1a-4935-ae39-afca620c461d",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_dataset = tokenized_dataset.flatten_indices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5fc7921-95b1-45b5-a88e-2a80f9b25ceb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "281bac4d-4a57-4de3-991e-d62fbe15b6af",
   "metadata": {},
   "source": [
    "### 定义 DataCollatorForChatGLM 类 批量处理数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "631a70f9-a75e-42d6-ad8b-5fd635db8d84",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from typing import List, Dict, Optional\n",
    "\n",
    "# DataCollatorForChatGLM 类\n",
    "class DataCollatorForChatGLM:\n",
    "    \"\"\"\n",
    "    用于处理批量数据的DataCollator，尤其是在使用 ChatGLM 模型时。\n",
    "\n",
    "    该类负责将多个数据样本（tokenized input）合并为一个批量，并在必要时进行填充(padding)。\n",
    "\n",
    "    属性:\n",
    "    pad_token_id (int): 用于填充(padding)的token ID。\n",
    "    max_length (int): 单个批量数据的最大长度限制。\n",
    "    ignore_label_id (int): 在标签中用于填充的ID。\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pad_token_id: int, max_length: int = 2048, ignore_label_id: int = -100):\n",
    "        \"\"\"\n",
    "        初始化DataCollator。\n",
    "\n",
    "        参数:\n",
    "        pad_token_id (int): 用于填充(padding)的token ID。\n",
    "        max_length (int): 单个批量数据的最大长度限制。\n",
    "        ignore_label_id (int): 在标签中用于填充的ID，默认为-100。\n",
    "        \"\"\"\n",
    "        self.pad_token_id = pad_token_id\n",
    "        self.ignore_label_id = ignore_label_id\n",
    "        self.max_length = max_length\n",
    "\n",
    "    def __call__(self, batch_data: List[Dict[str, List]]) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"\n",
    "        处理批量数据。\n",
    "\n",
    "        参数:\n",
    "        batch_data (List[Dict[str, List]]): 包含多个样本的字典列表。\n",
    "\n",
    "        返回:\n",
    "        Dict[str, torch.Tensor]: 包含处理后的批量数据的字典。\n",
    "        \"\"\"\n",
    "        # 计算批量中每个样本的长度\n",
    "        len_list = [len(d['input_ids']) for d in batch_data]\n",
    "        batch_max_len = max(len_list)  # 找到最长的样本长度\n",
    "\n",
    "        input_ids, labels = [], []\n",
    "        for len_of_d, d in sorted(zip(len_list, batch_data), key=lambda x: -x[0]):\n",
    "            pad_len = batch_max_len - len_of_d  # 计算需要填充的长度\n",
    "            # 添加填充，并确保数据长度不超过最大长度限制\n",
    "            ids = d['input_ids'] + [self.pad_token_id] * pad_len\n",
    "            label = d['labels'] + [self.ignore_label_id] * pad_len\n",
    "            if batch_max_len > self.max_length:\n",
    "                ids = ids[:self.max_length]\n",
    "                label = label[:self.max_length]\n",
    "            input_ids.append(torch.LongTensor(ids))\n",
    "            labels.append(torch.LongTensor(label))\n",
    "\n",
    "        # 将处理后的数据堆叠成一个tensor\n",
    "        input_ids = torch.stack(input_ids)\n",
    "        labels = torch.stack(labels)\n",
    "\n",
    "        return {'input_ids': input_ids, 'labels': labels}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "790fb86b-7662-4c49-8e3b-3db239c60838",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 准备数据整理器\n",
    "data_collator = DataCollatorForChatGLM(pad_token_id=tokenizer.pad_token_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff4124a3-e2a5-4e68-828b-3db3590d6f04",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f280438-3192-405d-a7bc-095831af197f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "0a223fcf-630a-4978-ad96-bfe93c6581ea",
   "metadata": {},
   "source": [
    "## 训练模型\n",
    "\n",
    "### 加载 ChatGLM3-6B 量化模型\n",
    "\n",
    "使用 `nf4` 量化数据类型加载模型，开启双量化配置，以`bf16`混合精度训练，预估显存占用接近4GB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a1480790-6035-4212-bc64-4292f6a908ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModel, BitsAndBytesConfig\n",
    "\n",
    "_compute_dtype_map = {\n",
    "    'fp32': torch.float32,\n",
    "    'fp16': torch.float16,\n",
    "    'bf16': torch.bfloat16\n",
    "}\n",
    "\n",
    "# QLoRA 量化配置\n",
    "q_config = BitsAndBytesConfig(load_in_4bit=True,\n",
    "                              bnb_4bit_quant_type='nf4',\n",
    "                              bnb_4bit_use_double_quant=True,\n",
    "                              bnb_4bit_compute_dtype=_compute_dtype_map['bf16'])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd8738b8-f510-4d9c-b789-d5afda67e39b",
   "metadata": {},
   "source": [
    "### 加载模型\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c6ab5f48-a7ce-492a-bb8b-4c0c3ff04bd0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4a24fa9650544ddd9019b760ff34e41a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# revision='b098244' 版本对应的 ChatGLM3-6B 设置 use_reentrant=False\n",
    "# 最新版本 use_reentrant 被设置为 True，会增加不必要的显存开销\n",
    "model = AutoModel.from_pretrained(model_name_or_path,\n",
    "                                  quantization_config=q_config,\n",
    "                                  device_map='auto',\n",
    "                                  trust_remote_code=True,\n",
    "                                  revision='b098244')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "aaabfe10-d762-498a-87f1-1506d2ea2fe9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3739.69MiB\n"
     ]
    }
   ],
   "source": [
    "# 获取当前模型占用的 GPU显存（差值为预留给 PyTorch 的显存）\n",
    "memory_footprint_bytes = model.get_memory_footprint()\n",
    "memory_footprint_mib = memory_footprint_bytes / (1024 ** 2)  # 转换为 MiB\n",
    "\n",
    "print(f\"{memory_footprint_mib:.2f}MiB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b58f580-34c9-4f7f-b4cb-86548f25f2b1",
   "metadata": {},
   "source": [
    "### 预处理量化模型\n",
    "\n",
    "预处理量化后的模型，使其可以支持低精度微调训练\n",
    "\n",
    "ref: https://huggingface.co/docs/peft/main/en/developer_guides/quantization#quantize-a-model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5f59a712-4829-4929-b795-31ffdfab1761",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.\n"
     ]
    }
   ],
   "source": [
    "from peft import TaskType, LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
    "\n",
    "kbit_model = prepare_model_for_kbit_training(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21649f5e-3f31-44ff-ab0b-82ef1614491a",
   "metadata": {},
   "source": [
    "### 自定义模型新增 Adapter \n",
    "\n",
    "当新的热门 transformer 网络架构（新模型）发布时，Huggingface 社区会尽力快速将它们添加到PEFT中。\n",
    "\n",
    "如果是 Hugging Face Transformers 库还未内置支持的模型，可以使用自定义模型的方式进行配置。\n",
    "\n",
    "具体来说，在初始化相应的微调配置类（例如`LoraConfig`）时，我们需要显式指定在哪些层新增适配器（Adapter），并将其设置正确。\n",
    "\n",
    "ref: https://huggingface.co/docs/peft/developer_guides/custom_models\n",
    "\n",
    "\n",
    "#### PEFT 适配模块设置\n",
    "\n",
    "\n",
    "在PEFT库的 [constants.py](https://github.com/huggingface/peft/blob/main/src/peft/utils/constants.py) 文件中定义了不同的 PEFT 方法，在各类大模型上的微调适配模块。\n",
    "\n",
    "通常，名称相同的模型架构也类似，应用微调方法时的适配器设置也几乎一致。\n",
    "\n",
    "例如，如果新模型架构是`mistral`模型的变体，并且您想应用 LoRA 微调。在 TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING中`mistral`包含[\"q_proj\", \"v_proj\"]。\n",
    "\n",
    "这表示说，对于`mistral`模型，LoRA 的 target_modules 通常是 [\"q_proj\", \"v_proj\"]。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b4dd1cb6-c1b4-430b-b031-c842ddc683a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING\n",
    "\n",
    "target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING['chatglm']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "517018b1-dff0-4caf-a1ff-21933f95e4b5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['query_key_value']"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_modules"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d8b7c08-931e-4764-9bfb-4230bc825561",
   "metadata": {},
   "source": [
    "### LoRA 适配器配置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "ab44290b-7bf2-44f0-ac36-ef0d41527224",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_config = LoraConfig(\n",
    "    target_modules=target_modules,\n",
    "    r=lora_rank,\n",
    "    lora_alpha=lora_alpha,\n",
    "    lora_dropout=lora_dropout,\n",
    "    bias='none',\n",
    "    inference_mode=False,\n",
    "    task_type=TaskType.CAUSAL_LM\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "bee2f80a-ccac-4c85-965e-ce5e5be528e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "qlora_model = get_peft_model(kbit_model, lora_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9cf0acae-9e20-4d30-9999-53c3079cc42c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 974,848 || all params: 6,244,558,848 || trainable%: 0.01561115883009451\n"
     ]
    }
   ],
   "source": [
    "qlora_model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d84f15b-93e1-4d90-80b9-ead28cd98406",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d8043c51-fa42-4867-9382-103ad5b082f3",
   "metadata": {},
   "source": [
    "### 训练超参数配置\n",
    "\n",
    "- 1个epoch表示对训练集的所有样本进行一次完整的训练。\n",
    "- `num_train_epochs` 表示要完整进行多少个 epochs 的训练。\n",
    "\n",
    "#### 关于使用 num_train_epochs 时，训练总步数 `steps` 的计算方法\n",
    "\n",
    "- 训练总步数： `total_steps = steps/epoch * num_train_epochs` \n",
    "- 每个epoch的训练步数：`steps/epoch = num_train_examples / (batch_size * gradient_accumulation_steps)`\n",
    "\n",
    "\n",
    "**以 `adgen` 数据集为例计算**\n",
    "\n",
    "```json\n",
    "DatasetDict({\n",
    "    train: Dataset({\n",
    "        features: ['content', 'summary'],\n",
    "        num_rows: 114599\n",
    "    })\n",
    "    validation: Dataset({\n",
    "        features: ['content', 'summary'],\n",
    "        num_rows: 1070\n",
    "    })\n",
    "})\n",
    "```\n",
    "\n",
    "代入超参数和配置进行计算：\n",
    "\n",
    "```python\n",
    "num_train_epochs = 1\n",
    "num_train_examples = 114599\n",
    "batch_size = 16\n",
    "gradient_accumulation_steps = 4\n",
    "\n",
    "\n",
    "steps = num_train_epochs * num_train_examples / (batch_size * gradient_accumulation_steps)\n",
    "      = 1 * 114599 / (16 * 4)\n",
    "      = 1790\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "aa527b7e-2dd6-46f0-ba7f-49f1430695a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"models/{model_name_or_path}\",          # 输出目录\n",
    "    per_device_train_batch_size=16,                     # 每个设备的训练批量大小\n",
    "    gradient_accumulation_steps=4,                     # 梯度累积步数\n",
    "    # per_device_eval_batch_size=8,                      # 每个设备的评估批量大小\n",
    "    learning_rate=1e-3,                                # 学习率\n",
    "    num_train_epochs=1,                                # 训练轮数\n",
    "    lr_scheduler_type=\"linear\",                        # 学习率调度器类型\n",
    "    warmup_ratio=0.1,                                  # 预热比例\n",
    "    logging_steps=10,                                 # 日志记录步数\n",
    "    save_strategy=\"steps\",                             # 模型保存策略\n",
    "    save_steps=100,                                    # 模型保存步数\n",
    "    # evaluation_strategy=\"steps\",                       # 评估策略\n",
    "    # eval_steps=500,                                    # 评估步数\n",
    "    optim=\"adamw_torch\",                               # 优化器类型\n",
    "    fp16=True,                                        # 是否使用混合精度训练\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "563f6592-1dac-433e-8195-c24761350cd5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 4.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    }
   ],
   "source": [
    "trainer = Trainer(\n",
    "        model=qlora_model,\n",
    "        args=training_args,\n",
    "        train_dataset=tokenized_dataset,\n",
    "        data_collator=data_collator\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1ba1de6-6e37-42a1-9ff0-4e926dee5a74",
   "metadata": {},
   "source": [
    "#### 训练参数（用于演示）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "6164c2c8-47b4-4472-a192-952d8e425554",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_demo_args = TrainingArguments(\n",
    "    output_dir=f\"models/demo/{model_name_or_path}\",          # 输出目录\n",
    "    per_device_train_batch_size=16,                     # 每个设备的训练批量大小\n",
    "    gradient_accumulation_steps=4,                     # 梯度累积步数\n",
    "    learning_rate=1e-3,                                # 学习率\n",
    "    max_steps=100,                                     # 训练步数\n",
    "    lr_scheduler_type=\"linear\",                        # 学习率调度器类型\n",
    "    warmup_ratio=0.1,                                  # 预热比例\n",
    "    logging_steps=10,                                 # 日志记录步数\n",
    "    save_strategy=\"steps\",                             # 模型保存策略\n",
    "    save_steps=20,                                    # 模型保存步数\n",
    "    optim=\"adamw_torch\",                               # 优化器类型\n",
    "    fp16=True,                                        # 是否使用混合精度训练\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "9655c7a7-75fd-4f0a-b190-e97a533c9bae",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 4.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    }
   ],
   "source": [
    "trainer = Trainer(\n",
    "        model=qlora_model,\n",
    "        args=training_demo_args,\n",
    "        train_dataset=tokenized_dataset,\n",
    "        data_collator=data_collator\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7073aa2a-22f0-4ac1-ba93-5d8953ddb18f",
   "metadata": {},
   "source": [
    "### 开始训练\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "07413af1-dff0-466a-adaa-00a80f327350",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n",
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='100' max='100' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [100/100 46:45, Epoch 0/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>4.694100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>3.805300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>3.590100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>3.502200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>3.440100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>3.411000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>3.399500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>3.375400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>3.359900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>3.328900</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Checkpoint destination directory models/demo/THUDM/chatglm3-6b/checkpoint-20 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "Checkpoint destination directory models/demo/THUDM/chatglm3-6b/checkpoint-40 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "Checkpoint destination directory models/demo/THUDM/chatglm3-6b/checkpoint-60 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "Checkpoint destination directory models/demo/THUDM/chatglm3-6b/checkpoint-80 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "Checkpoint destination directory models/demo/THUDM/chatglm3-6b/checkpoint-100 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=100, training_loss=3.5906502532958986, metrics={'train_runtime': 2832.2774, 'train_samples_per_second': 2.26, 'train_steps_per_second': 0.035, 'total_flos': 3.840600180523008e+16, 'train_loss': 3.5906502532958986, 'epoch': 0.06})"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "320cd8a8-3690-4a2b-b725-f4c1316646ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.model.save_pretrained(f\"models/demo/{model_name_or_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9909f94a-b993-41b4-99f2-95a3b8b38c14",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2111a492-d605-4b69-b98b-7db97e45f7f3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
